{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Legendre Memory Units in Nengo\n",
    "\n",
    "Legendre Memory Units (LMUs) are\n",
    "a novel recurrent neural network architecture, described in\n",
    "[Voelker, Kajić, and Eliasmith (NeurIPS 2019)][paper].\n",
    "We will not go into the underlying details of the method here;\n",
    "broadly speaking, we can think of an LMU as a recurrent network\n",
    "that does a very good job of representing\n",
    "the temporal information in some input signal.\n",
    "Since most RNN tasks involve computing\n",
    "some function of that temporal information,\n",
    "the better the RNN is at representing the temporal information,\n",
    "the better it will be able to perform the task.\n",
    "See the [paper][] for all the details!\n",
    "\n",
    "In this example we will show how an LMU can be used\n",
    "to delay an input signal for some fixed length of time.\n",
    "This is a simple-sounding task, but performing an accurate delay\n",
    "requires the network to store the complete history of the input signal\n",
    "across the delay period.\n",
    "This makes it a good measure of a network's fundamental temporal storage.\n",
    "\n",
    "[paper]: https://papers.nips.cc/paper/9689-legendre-memory-units-continuous-time-representation-in-recurrent-neural-networks.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "from collections import deque\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "import nengo\n",
    "from nengo.utils.filter_design import cont2discrete"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our LMU in this example will have two parameters:\n",
    "the length of the time window it is optimized to store,\n",
    "and the number of Legendre polynomials used to represent the signal\n",
    "(using higher-order polynomials\n",
    "allows the LMU to represent higher-frequency information).\n",
    "\n",
    "The input will be a band-limited white noise signal,\n",
    "which has its own parameters\n",
    "determining the amplitude and frequency of the signal.\n",
    "\n",
    "Feel free to adjust any of these parameters\n",
    "to see what impact they have on the model's performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# parameters of LMU\n",
    "theta = 1.0  # length of window (in seconds)\n",
    "order = 8  # number of Legendre polynomials representing window\n",
    "\n",
    "# parameters of input signal\n",
    "freq = 2  # frequency limit (in Hz)\n",
    "rms = 0.30  # RMS amplitude of input (set to keep within [-1, 1])\n",
    "delay = 0.5  # length of time delay network will learn (in seconds)\n",
    "\n",
    "# simulation parameters\n",
    "dt = 0.001  # simulation timestep (in seconds)\n",
    "sim_t = 100  # length of simulation (in seconds)\n",
    "seed = 0  # fixed for deterministic results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we need to compute\n",
    "the analytically derived weight matrices used in the LMU.\n",
    "These are determined statically based on\n",
    "the `theta`/`order` parameters from above.\n",
    "It is also possible to optimize these matrices with backpropagation,\n",
    "using a framework such as [Nengo DL](https://www.nengo.ai/nengo-dl)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# compute the A and B matrices according to the LMU's mathematical derivation\n",
    "# (see the paper for details)\n",
    "Q = np.arange(order, dtype=np.float64)\n",
    "R = (2 * Q + 1)[:, None] / theta\n",
    "j, i = np.meshgrid(Q, Q)\n",
    "\n",
    "A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R\n",
    "B = (-1.0) ** Q[:, None] * R\n",
    "\n",
    "# C and D are placeholders needed to call cont2discrete;\n",
    "# only the discretized A and B are used below\n",
    "C = np.ones((1, order))\n",
    "D = np.zeros((1,))\n",
    "\n",
    "# discretize the continuous-time system with a zero-order hold\n",
    "A, B, _, _, _ = cont2discrete((A, B, C, D), dt=dt, method=\"zoh\")"
   ]
  },
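  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the matrices built in the cell above can be written out explicitly.\n",
    "This is only a transcription of the NumPy code, using zero-based indices\n",
    "$i, j = 0, \\ldots, d - 1$ with $d$ equal to `order`;\n",
    "the symbols $a_{ij}$, $b_i$, $\\bar{A}$, $\\bar{B}$, and $\\Delta t$ (for `dt`)\n",
    "are notation assumed here for illustration:\n",
    "\n",
    "$$\n",
    "a_{ij} = \\frac{2i + 1}{\\theta}\n",
    "\\begin{cases}\n",
    "-1 & i < j \\\\\n",
    "(-1)^{i - j + 1} & i \\ge j\n",
    "\\end{cases},\n",
    "\\qquad\n",
    "b_i = \\frac{(2i + 1)(-1)^i}{\\theta}\n",
    "$$\n",
    "\n",
    "These define the continuous-time memory $\\dot{m}(t) = A m(t) + B u(t)$.\n",
    "Assuming $A$ is invertible, the zero-order-hold discretization\n",
    "has the standard closed form\n",
    "\n",
    "$$\n",
    "\\bar{A} = e^{A \\Delta t},\n",
    "\\qquad\n",
    "\\bar{B} = A^{-1} \\left( e^{A \\Delta t} - I \\right) B\n",
    "$$\n",
    "\n",
    "so that the memory advances as $m_{k+1} = \\bar{A} m_k + \\bar{B} u_k$."
   ]
  },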
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we will set up an artificial synapse model\n",
    "to compute an ideal delay\n",
    "(we'll use this to train the model later on).\n",
    "We can then run a simple network containing\n",
    "just our input signal and the ideal delay to see what that looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class IdealDelay(nengo.synapses.Synapse):\n",
    "    def __init__(self, delay):\n",
    "        super().__init__()\n",
    "        self.delay = delay\n",
    "\n",
    "    def make_state(self, *args, **kwargs):\n",
    "        return {}\n",
    "\n",
    "    def make_step(self, shape_in, shape_out, dt, rng, state):\n",
    "        # buffer the input signal based on the delay length\n",
    "        # (round first so floating point error cannot drop a timestep)\n",
    "        buffer = deque([0] * int(round(self.delay / dt)))\n",
    "\n",
    "        def delay_func(t, x):\n",
    "            buffer.append(x.copy())\n",
    "            return buffer.popleft()\n",
    "\n",
    "        return delay_func\n",
    "\n",
    "\n",
    "with nengo.Network(seed=seed) as net:\n",
    "    # create the input signal\n",
    "    stim = nengo.Node(\n",
    "        output=nengo.processes.WhiteSignal(\n",
    "            high=freq, period=sim_t, rms=rms, y0=0, seed=seed\n",
    "        )\n",
    "    )\n",
    "    \n",
    "    # probe input signal and an ideally delayed version of input signal\n",
    "    p_stim = nengo.Probe(stim)\n",
    "    p_ideal = nengo.Probe(stim, synapse=IdealDelay(delay))\n",
    "    \n",
    "# run the network and display results\n",
    "with nengo.Simulator(net) as sim:\n",
    "    sim.run(10)\n",
    "    \n",
    "    plt.figure(figsize=(16, 6))\n",
    "    plt.plot(sim.trange(), sim.data[p_stim], label=\"input\")\n",
    "    plt.plot(sim.trange(), sim.data[p_ideal], label=\"ideal\")\n",
    "    plt.legend();"
   ]
  },
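  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The buffering idea behind `IdealDelay` can be seen in isolation, without Nengo.\n",
    "The following is a minimal plain-Python sketch;\n",
    "`make_delay_line` and `n_steps` are hypothetical names\n",
    "used only for this illustration and are not part of Nengo.\n",
    "\n",
    "```python\n",
    "from collections import deque\n",
    "\n",
    "\n",
    "def make_delay_line(n_steps):\n",
    "    # hypothetical helper (not a Nengo API): delays a stream by n_steps samples\n",
    "    buffer = deque([0.0] * n_steps)  # pre-filled so early outputs are zero\n",
    "\n",
    "    def step(x):\n",
    "        buffer.append(x)  # newest sample goes in on the right\n",
    "        return buffer.popleft()  # oldest sample comes out on the left\n",
    "\n",
    "    return step\n",
    "\n",
    "\n",
    "delay_line = make_delay_line(3)\n",
    "outputs = [delay_line(x) for x in [1.0, 2.0, 3.0, 4.0, 5.0]]\n",
    "print(outputs)\n",
    "```\n",
    "\n",
    "The first `n_steps` outputs are the zeros used to pre-fill the buffer;\n",
    "after that, each output is the input from `n_steps` samples earlier."
   ]
  },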
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we are ready to build the LMU.\n",
    "The full LMU architecture consists of two components:\n",
    "a linear memory, and a nonlinear hidden state.\n",
    "But the nonlinear hidden state is really only useful\n",
    "when it is optimized using backpropagation (see\n",
    "[this example in NengoDL](https://www.nengo.ai/nengo-dl/examples/lmu.html)).\n",
    "So here we will just build the linear memory component."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with net:\n",
    "    lmu = nengo.Node(size_in=order)\n",
    "    nengo.Connection(stim, lmu, transform=B, synapse=None)\n",
    "    nengo.Connection(lmu, lmu, transform=A, synapse=0)"
   ]
  },
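  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the recurrence concrete, here is a minimal sketch\n",
    "of the same kind of discrete-time update for the simplest case, `order = 1`,\n",
    "written in plain Python.\n",
    "It assumes the zero-order-hold formulas for a scalar system;\n",
    "the names `a`, `b`, `a_bar`, `b_bar`, and `m` are illustrative\n",
    "and not part of Nengo.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "theta = 1.0  # window length (in seconds)\n",
    "dt = 0.001  # timestep (in seconds)\n",
    "\n",
    "# continuous-time coefficients for order = 1\n",
    "# (the A and B formulas reduce to these scalars)\n",
    "a = -1.0 / theta\n",
    "b = 1.0 / theta\n",
    "\n",
    "# zero-order-hold discretization of a scalar system\n",
    "a_bar = math.exp(a * dt)\n",
    "b_bar = (a_bar - 1.0) / a * b\n",
    "\n",
    "# drive the memory with a constant input of 1 for five window lengths\n",
    "m = 0.0\n",
    "for _ in range(int(round(5 * theta / dt))):\n",
    "    m = a_bar * m + b_bar * 1.0\n",
    "\n",
    "print(m)\n",
    "```\n",
    "\n",
    "For `order = 1` the update reduces to a first-order low-pass filter\n",
    "with time constant `theta`, so the state settles toward a constant input."
   ]
  },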
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On its own, the LMU isn't performing a task;\n",
    "it is just internally representing the input signal.\n",
    "So to get this network to perform a function,\n",
    "we will add an output Ensemble\n",
    "that gets the output of the LMU as input.\n",
    "Then we will train the output weights of that Ensemble\n",
    "using the PES online learning rule.\n",
    "The error signal will be based on the ideally delayed input signal,\n",
    "so the network should learn to compute that same delay."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with net:\n",
    "    ens = nengo.Ensemble(1000, order, neuron_type=nengo.SpikingRectifiedLinear())\n",
    "    nengo.Connection(lmu, ens, synapse=None)\n",
    "\n",
    "    out = nengo.Node(size_in=1)\n",
    "\n",
    "    # we'll use a Node to compute the error signal so that we can shut off\n",
    "    # learning after a while (in order to assess the network's generalization)\n",
    "    err_node = nengo.Node(lambda t, x: x if t < sim_t * 0.8 else 0, size_in=1)\n",
    "    \n",
    "    # the target signal is the ideally delayed version of the input signal,\n",
    "    # which is subtracted from the ensemble's output in order to compute the\n",
    "    # PES error\n",
    "    nengo.Connection(stim, err_node, synapse=IdealDelay(delay), transform=-1)\n",
    "    nengo.Connection(out, err_node, synapse=None)\n",
    "    \n",
    "    learn_conn = nengo.Connection(\n",
    "        ens, out, function=lambda x: 0, learning_rule_type=nengo.PES(2e-4)\n",
    "    )\n",
    "    nengo.Connection(err_node, learn_conn.learning_rule, synapse=None)\n",
    "\n",
    "    p_out = nengo.Probe(out)"
   ]
  },
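  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea of error-driven decoder learning can be sketched in plain Python.\n",
    "This is a simplified delta rule in the spirit of PES,\n",
    "not Nengo's implementation: it ignores spiking, synaptic filtering,\n",
    "and timestep scaling, and the helper `delta_rule_update`,\n",
    "the activities, and the target are all made up for illustration.\n",
    "\n",
    "```python\n",
    "def delta_rule_update(decoders, activities, error, learning_rate):\n",
    "    # hypothetical helper: move each decoder against the error, scaled by\n",
    "    # that neuron's activity and normalized by the number of neurons\n",
    "    n = len(activities)\n",
    "    scale = learning_rate / n\n",
    "    return [d - scale * error * a for d, a in zip(decoders, activities)]\n",
    "\n",
    "\n",
    "activities = [0.2, 0.5, 1.0]  # made-up, fixed neuron activities\n",
    "target = 0.7  # made-up target output\n",
    "decoders = [0.0, 0.0, 0.0]  # start from a zero output\n",
    "\n",
    "for _ in range(200):\n",
    "    output = sum(d * a for d, a in zip(decoders, activities))\n",
    "    error = output - target  # actual minus target\n",
    "    decoders = delta_rule_update(decoders, activities, error, learning_rate=1.0)\n",
    "\n",
    "final_output = sum(d * a for d, a in zip(decoders, activities))\n",
    "print(final_output)\n",
    "```\n",
    "\n",
    "Each update shrinks the error by a constant factor in this toy setting,\n",
    "so the decoded output approaches the target."
   ]
  },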
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can run the full model\n",
    "to see it learning to perform the delay task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(net) as sim:\n",
    "    sim.run(sim_t)\n",
    "\n",
    "# we'll break up the output into multiple plots, just for\n",
    "# display purposes\n",
    "t = sim.trange()\n",
    "t_per_plot = 10\n",
    "for i in range(sim_t // t_per_plot):\n",
    "    plot_slice = (t >= t_per_plot * i) & (t < t_per_plot * (i + 1))\n",
    "\n",
    "    plt.figure(figsize=(16, 6))\n",
    "    plt.plot(t[plot_slice], sim.data[p_stim][plot_slice], label=\"input\")\n",
    "    plt.plot(t[plot_slice], sim.data[p_ideal][plot_slice], label=\"ideal\")\n",
    "    plt.plot(t[plot_slice], sim.data[p_out][plot_slice], label=\"output\")\n",
    "    if i * t_per_plot < sim_t * 0.8:\n",
    "        plt.title(\"Learning ON\")\n",
    "    else:\n",
    "        plt.title(\"Learning OFF\")\n",
    "    plt.ylim([-1, 1])\n",
    "    plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the network is successfully learning to compute a delay.\n",
    "We could use these same principles to train a network\n",
    "to compute other functions of an input signal's recent history,\n",
    "with the LMU providing a compact representation\n",
    "of that history for the output weights to draw on.\n",
    "\n",
    "See this example for another application of LMUs:\n",
    "\n",
    "- [State-of-the-art performance on the psMNIST task using LMUs in NengoDL](\n",
    "https://www.nengo.ai/nengo-dl/examples/lmu.html)\n",
    "\n",
    "The [original paper][paper] has more information on LMUs.\n",
    "\n",
    "[paper]: https://papers.nips.cc/paper/9689-legendre-memory-units-continuous-time-representation-in-recurrent-neural-networks.pdf"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
